{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "27d5425deb10849c",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#  第三章：编写注意力机制"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "755ce6dff684c41",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " 在这个notebook中使用的包有："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e58f33e8-5dc9-4dd5-ab84-5a011fa11d92",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:12.683615Z",
     "start_time": "2024-03-01T07:05:09.675943900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch version: 2.1.1\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "import torch\n",
    "\n",
    "print(\"torch version:\", version(\"torch\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d0475ea32ec926b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " ## 3.1 长序列建模的问题"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "929f224b96fb1a27",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 这个部分没有代码。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81f7a179e9cf96c7",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " ## 3.2 使用注意力机制捕获数据依赖性"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6fde64c-6034-421d-81d9-8244932086ea",
   "metadata": {},
   "source": [
    "- 这个部分没有代码。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a6fc49dd41e1c19",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " ## 3.3 使用自注意力关注输入的不同部分"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2bf1532f595d316",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " ### 3.3.1 一个简单的自注意力机制，不包含可训练权重"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66cdd1ec5345c45e",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " - 本部分介绍了一个极其简化的自注意力机制版本，它不包含任何可训练的权重。这只是为了说明概念，并不是实际在 transformers 模型中使用的注意力机制。接下来的3.3.2节将扩展这个简单的注意力机制，介绍真正的自注意力机制。\n",
    "- 假设我们有一个输入序列，从 $x^{(1)}$ 到 $x^{(T)}$。\n",
    "  - 输入是一段文本（比如句子 \"Your journey starts with one step\"），它已经被转换成了第二章中描述的标记嵌入形式。\n",
    "  - 例如，$x^{(1)}$ 是一个 d 维的向量，代表了单词 \"Your\"，依此类推。\n",
    "- **目标：** 为输入序列中的每个元素 $x^{(i)}$（从 $x^{(1)}$ 到 $x^{(T)}$，其中 $z$ 和 $x$ 的维度相同）计算上下文向量 $z^{(i)}$。\n",
    "    - 上下文向量 $z^{(i)}$ 是对输入 $x^{(1)}$ 到 $x^{(T)}$ 的加权平均。\n",
    "    - 上下文向量是针对特定输入的上下文信息。\n",
    "      - 我们不使用 $x^{(i)}$ 作为任意输入标记的占位符，而是考虑第二个输入，$x^{(2)}$。\n",
    "      - 为了具体说明，我们不是用占位符 $z^{(i)}$，而是考虑第二个输出的上下文向量，$z^{(2)}$。\n",
    "      - 第二个上下文向量 $z^{(2)}$ 是对所有输入 $x^{(1)}$ 到 $x^{(T)}$ 的加权平均，权重是根据第二个输入元素 $x^{(2)}$ 来确定的。这些注意力权重决定了在计算 $z^{(2)}$ 时，每个输入元素对最终加权平均的贡献程度。\n",
    "    - 简而言之，可以把 $z^{(2)}$ 看作是 $x^{(2)}$ 的一个变体，它不仅包含了 $x^{(2)}$ 的信息，还融合了与当前任务相关的所有其他输入元素的信息。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e89766c8b4d562a1",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 按照惯例，未经归一化的注意力权重被称为**“注意力分数”**，而归一化后的注意力分数（它们的和为1）被称为**“注意力权重”**。\n",
    "\n",
    "- 注意力权重和上下文向量的计算总结在下面的图表中："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28531e83-85bd-43a4-8928-57bb0372d9c7",
   "metadata": {},
   "source": [
    "<img src=\"figures/attention.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12dfcc8a4890c11f",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " - 下面的代码逐步展示了上面图表的内容。\n",
    "\n",
    "<br>\n",
    "\n",
    "- **步骤 1：** 计算未归一化的注意力分数 $\\omega$。\n",
    "  - 假设我们使用第二个输入标记作为查询，即 $q^{(2)} = x^{(2)}$，我们通过点积来计算未归一化的注意力分数：\n",
    "    - $\\omega_{21} = x^{(1)} \\cdot q^{(2)\\top}$\n",
    "    - $\\omega_{22} = x^{(2)} \\cdot q^{(2)\\top}$\n",
    "    - $\\omega_{23} = x^{(3)} \\cdot q^{(2)\\top}$\n",
    "    - ...\n",
    "    - $\\omega_{2T} = x^{(T)} \\cdot q^{(2)\\top}$\n",
    "  - 在这里，$\\omega$ 是希腊字母 \"omega\"，用来表示未归一化的注意力分数。\n",
    "    - $\\omega_{21}$ 中的下标 \"21\" 表示输入序列的第2个元素被用作查询，与输入序列的第1个元素进行比较。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e29440f-9b77-4966-83aa-d1ff2e653b00",
   "metadata": {},
   "source": [
    "<img src=\"figures/dot-product.png\" width=\"450px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0c28811e45031fd",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 假设我们有以下已经转换成3维向量的输入句子，如第三章所述（为了说明方便，这里使用了一个非常小的嵌入维度，以便在不换行的情况下适应页面）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:23.298987200Z",
     "start_time": "2024-03-01T07:05:23.284983400Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "inputs = torch.tensor(\n",
    "  [[0.43, 0.15, 0.89], # Your     (x^1)\n",
    "   [0.55, 0.87, 0.66], # journey  (x^2)\n",
    "   [0.57, 0.85, 0.64], # starts   (x^3)\n",
    "   [0.22, 0.58, 0.33], # with     (x^4)\n",
    "   [0.77, 0.25, 0.10], # one      (x^5)\n",
    "   [0.05, 0.80, 0.55]] # step     (x^6)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c98cfa901b290ae",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 我们以输入序列中的第二个元素 $x^{(2)}$ 为例，来计算上下文向量 $z^{(2)}$；在后面的部分，我们将推广这个方法来计算所有的上下文向量。\n",
    "- 第一步是计算未归一化的注意力分数，通过计算查询 $x^{(2)}$ 与所有其他输入标记之间的点积来实现："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1540227deed6d1da",
   "metadata": {
    "collapsed": false
   },
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:25.477473800Z",
     "start_time": "2024-03-01T07:05:25.459470Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])\n"
     ]
    }
   ],
   "source": [
    "# 从输入序列中取出第二个元素作为查询向量\n",
    "query = inputs[1]\n",
    "\n",
    "# 创建一个空的张量来存储注意力分数，其形状与输入序列的批次大小相同\n",
    "attn_scores_2 = torch.empty(inputs.shape[0])\n",
    "\n",
    "# 遍历输入序列的每个元素\n",
    "for i, x_i in enumerate(inputs):\n",
    "    # 计算当前元素与查询向量的点积作为注意力分数\n",
    "    # 这里不需要转置，因为假设输入向量是一维的\n",
    "    attn_scores_2[i] = torch.dot(x_i, query)\n",
    "\n",
    "# 打印注意力分数\n",
    "print(attn_scores_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebbb77f59671f0aa",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 注：点积实际上是一种简写，它表示将两个向量的对应元素相乘，然后将这些乘积相加求和："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ef00c65e10fddb4",
   "metadata": {
    "collapsed": false
   },
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9842f39b-1654-410e-88bf-d1b899bf0241",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:27.317884600Z",
     "start_time": "2024-03-01T07:05:27.292879100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9544)\n",
      "tensor(0.9544)\n"
     ]
    }
   ],
   "source": [
    "# 初始化结果变量为0\n",
    "res = 0.\n",
    "\n",
    "# 遍历输入序列的第一个元素中的每个元素\n",
    "for idx, element in enumerate(inputs[0]):\n",
    "    # 将当前元素与查询向量的对应元素相乘，并将结果累加到res\n",
    "    res += inputs[0][idx] * query[idx]\n",
    "\n",
    "# 打印手动计算的点积结果\n",
    "print(res)\n",
    "\n",
    "# 使用PyTorch的torch.dot函数计算点积，并打印结果\n",
    "print(torch.dot(inputs[0], query))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "389d1ba5c3db582b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- **第二步：** 对未归一化的注意力分数（称为“omegas”，用希腊字母 $\\omega$ 表示）进行归一化处理，使得它们的总和等于1。\n",
    "- 这里有一个简单的方法来归一化这些未归一化的注意力分数，以确保它们的总和为1（这是一个常用的做法，有助于理解，并且对训练过程的稳定性至关重要）："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffd31c47c7645e04",
   "metadata": {
    "collapsed": false
   },
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:29.228312Z",
     "start_time": "2024-03-01T07:05:29.180302Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])\n",
      "Sum: tensor(1.0000)\n"
     ]
    }
   ],
   "source": [
    "# 使用注意力分数的和对注意力分数进行归一化处理\n",
    "attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()\n",
    "\n",
    "# 打印归一化后的注意力权重\n",
    "print(\"Attention weights:\", attn_weights_2_tmp)\n",
    "# 验证归一化后的注意力权重之和是否为1\n",
    "print(\"Sum:\", attn_weights_2_tmp.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f085c4759c607872",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 然而，在实际应用中，使用softmax函数进行归一化更为常见且推荐，因为它更擅长处理极端值，在训练过程中具有更理想的梯度特性。\n",
    "- 以下是一个简单的softmax函数实现，它用于缩放，同时归一化向量元素，使得它们的总和为1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:30.924690900Z",
     "start_time": "2024-03-01T07:05:30.877680700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
      "Sum: tensor(1.)\n"
     ]
    }
   ],
   "source": [
    "# 定义一个简单的 softmax 函数实现\n",
    "def softmax_naive(x):\n",
    "    # 对输入张量 x 的每个元素应用指数函数\n",
    "    exp_x = torch.exp(x)\n",
    "    # 计算 exp_x 在指定维度（这里是第一个维度，dim=0）上的和\n",
    "    sum_exp_x = exp_x.sum(dim=0)\n",
    "    # 将 exp_x 的每个元素除以它们的和，得到 softmax 结果\n",
    "    return exp_x / sum_exp_x\n",
    "\n",
    "# 使用 naive softmax 函数对注意力分数进行归一化\n",
    "attn_weights_2_naive = softmax_naive(attn_scores_2)\n",
    "\n",
    "# 打印归一化后的注意力权重\n",
    "print(\"Attention weights:\", attn_weights_2_naive)\n",
    "# 验证归一化后的注意力权重之和是否为1\n",
    "print(\"Sum:\", attn_weights_2_naive.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10bbc97e55cdbd1c",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 上面的简单实现可能会因为输入值过大或过小而导致数值不稳定问题，这主要是因为数值溢出和下溢的问题。\n",
    "- 因此，在实际应用中，建议使用 PyTorch 提供的 `softmax` 函数实现，它经过高度优化，性能更优："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:32.360025300Z",
     "start_time": "2024-03-01T07:05:32.332020Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
      "Sum: tensor(1.)\n"
     ]
    }
   ],
   "source": [
    "# 使用 PyTorch 的 softmax 函数对注意力分数进行归一化\n",
    "# dim=0 表示在第一个维度（通常是特征维度）上进行 softmax 计算\n",
    "attn_weights_2 = torch.softmax(attn_scores_2, dim=0)\n",
    "\n",
    "# 打印归一化后的注意力权重\n",
    "print(\"Attention weights:\", attn_weights_2)\n",
    "# 验证归一化后的注意力权重之和是否为1\n",
    "print(\"Sum:\", attn_weights_2.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26834a0afec960d6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " - **第三步**：通过将嵌入的输入标记 $x^{(i)}$ 与注意力权重相乘，然后将得到的结果向量相加，来计算上下文向量 $z^{(2)}$："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:33.773341800Z",
     "start_time": "2024-03-01T07:05:33.765339800Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.4419, 0.6515, 0.5683])\n"
     ]
    }
   ],
   "source": [
    "# 选择输入序列中的第二个元素作为查询向量\n",
    "query = inputs[1]\n",
    "\n",
    "# 初始化上下文向量，其形状与查询向量相同，初始值为0\n",
    "context_vec_2 = torch.zeros(query.shape)\n",
    "\n",
    "# 遍历输入序列中的每个元素\n",
    "for i, x_i in enumerate(inputs):\n",
    "    # 累加每个输入元素与其对应的注意力权重的乘积\n",
    "    context_vec_2 += attn_weights_2[i] * x_i\n",
    "\n",
    "# 打印计算得到的上下文向量\n",
    "print(context_vec_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16b7a0e40c6d8d08",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 3.3.2 计算所有输入标记的注意力权重"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bfcbe08825a085b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "#### 推广到所有输入序列标记：\n",
    "\n",
    "- 在上面的内容中，我们计算了输入2的注意力权重和上下文向量（如下面图表中高亮行所示）。\n",
    "- 接下来，我们将推广这个计算过程，以计算所有输入标记的注意力权重和上下文向量。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11c0fb55-394f-42f4-ba07-d01ae5c98ab4",
   "metadata": {},
   "source": [
    "<img src=\"figures/attention-matrix.png\" width=\"400px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8dcd2e858df4af2",
   "metadata": {
    "collapsed": false
   },
   "source": [
    " - 应用之前的**第一步**，对所有成对的元素进行计算，以得到未归一化的注意力分数矩阵："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "04004be8-07a1-468b-ab33-32e16a551b45",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:36.389926600Z",
     "start_time": "2024-03-01T07:05:36.370922Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n",
      "        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n",
      "        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n",
      "        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n",
      "        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n",
      "        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n"
     ]
    }
   ],
   "source": [
    "# 创建一个 6x6 的零张量，用于存储注意力分数\n",
    "attn_scores = torch.empty(6, 6)\n",
    "\n",
    "# 遍历输入序列中的每个元素\n",
    "for i, x_i in enumerate(inputs):\n",
    "    # 对于当前的输入元素 x_i，再次遍历整个输入序列\n",
    "    for j, x_j in enumerate(inputs):\n",
    "        # 计算 x_i 和 x_j 的点积，作为注意力分数，并存储在 attn_scores 矩阵的对应位置\n",
    "        attn_scores[i, j] = torch.dot(x_i, x_j)\n",
    "\n",
    "# 打印完整的注意力分数矩阵\n",
    "print(attn_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4a64a8236579473",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 我们可以通过矩阵乘法更有效地实现上面的计算："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:39.111548400Z",
     "start_time": "2024-03-01T07:05:39.099546300Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n",
      "        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n",
      "        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n",
      "        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n",
      "        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n",
      "        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n"
     ]
    }
   ],
   "source": [
    "# 使用矩阵乘法计算输入序列的点积矩阵\n",
    "# inputs @ inputs.T 相当于 inputs 与 inputs 的转置相乘\n",
    "attn_scores = inputs @ inputs.T\n",
    "\n",
    "# 打印注意力分数矩阵\n",
    "print(attn_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277f7ce6c43bf3af",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 与之前的**第二步**类似，我们对每一行进行归一化，以使每一行的值之和为1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:40.860940800Z",
     "start_time": "2024-03-01T07:05:40.843935200Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],\n",
      "        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],\n",
      "        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],\n",
      "        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],\n",
      "        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],\n",
      "        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])\n"
     ]
    }
   ],
   "source": [
    "attn_weights = torch.softmax(attn_scores, dim=1)\n",
    "print(attn_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd1207b1f9b38e9c",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 快速验证每一行的值确实之和为1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:42.340270Z",
     "start_time": "2024-03-01T07:05:42.308262700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Row 2 sum: 1.0\n",
      "All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])\n"
     ]
    }
   ],
   "source": [
    "# 定义第二行的注意力权重列表\n",
    "row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
    "# 打印第二行的和\n",
    "print(\"Row 2 sum:\", row_2_sum)\n",
    "\n",
    "# 使用 PyTorch 的 sum 函数沿着指定维度（这里是维度1，即行）计算所有行的和\n",
    "# attn_weights.sum(dim=1) 会返回一个包含每行和的一维张量\n",
    "print(\"All row sums:\", attn_weights.sum(dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e9e3585324e487a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 应用之前的**第三步**，计算所有上下文向量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:43.990639200Z",
     "start_time": "2024-03-01T07:05:43.971634200Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.4421, 0.5931, 0.5790],\n",
      "        [0.4419, 0.6515, 0.5683],\n",
      "        [0.4431, 0.6496, 0.5671],\n",
      "        [0.4304, 0.6298, 0.5510],\n",
      "        [0.4671, 0.5910, 0.5266],\n",
      "        [0.4177, 0.6503, 0.5645]])\n"
     ]
    }
   ],
   "source": [
    "all_context_vecs = attn_weights @ inputs\n",
    "print(all_context_vecs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "298b13b5bb3d62d1",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "- 作为合理性检查，之前计算的上下文向量 $z^{(2)} = [0.4419, 0.6515, 0.5683]$ 可以在上面的第二行找到："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2570eb7d-aee1-457a-a61e-7544478219fa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:05:46.564249100Z",
     "start_time": "2024-03-01T07:05:46.548244600Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])\n"
     ]
    }
   ],
   "source": [
    "print(\"Previous 2nd context vector:\", context_vec_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a303b6fb-9f7e-42bb-9fdb-2adabf0a6525",
   "metadata": {},
   "source": [
    "## 3.4 使用可训练权重实现自注意力"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b90a77e-d746-4704-9354-1ddad86e6298",
   "metadata": {},
   "source": [
    "### 3.4.1 逐步计算注意力权重"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46e95a46-1f67-4b71-9e84-8e2db84ab036",
   "metadata": {},
   "source": [
    "- 在本节中，我们正在实现自注意力机制，这是原始 transformer 架构、GPT模型以及大多数其他流行的大型语言模型（LLMs）中使用的技术。\n",
    "- 这种自注意力机制也被称为“缩放点积注意力”。\n",
    "- 总体思路与之前类似：\n",
    "  - 我们想要计算上下文向量，作为特定输入元素的输入向量的加权和。\n",
    "  - 为此，我们需要注意力权重。\n",
    "- 你将看到，与之前介绍的基本注意力机制相比，只有细微的差别：\n",
    "  - 最显著的区别是引入了在模型训练期间更新的权重矩阵。\n",
    "  - 这些可训练的权重矩阵至关重要，以便模型（特别是模型内的注意力模块）能够学习产生 \"good\" 上下文向量。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d996671-87aa-45c9-b2e0-07a7bcc9060a",
   "metadata": {},
   "source": [
    "- 在逐步实现自注意力机制的过程中，我们将首先引入三个训练权重矩阵 $W_q$、$W_k$ 和 $W_v$。\n",
    "- 这三个矩阵用于通过矩阵乘法将嵌入的输入标记 $x^{(i)}$ 投影到查询、键和值向量：\n",
    "\n",
    "  - 查询向量：$q^{(i)} = W_q \\cdot x^{(i)}$\n",
    "  - 键向量：$k^{(i)} = W_k \\cdot x^{(i)}$\n",
    "  - 值向量：$v^{(i)} = W_v \\cdot x^{(i)}$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3b29bc6-4bde-4924-9aff-0af1421803f5",
   "metadata": {},
   "source": [
    "<img src=\"figures/weight-selfattn-1.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f334313-5fd0-477b-8728-04080a427049",
   "metadata": {},
   "source": [
    "- 输入 $x$ 的嵌入维度和查询向量 $q$ 的维度可以相同也可以不同，这取决于模型的设计和具体实现。\n",
    "- 在 GPT 模型中，输入和输出的维度通常是相同的，但为了更好地说明计算过程，我们在这里选择不同的输入和输出维度："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:05.548505500Z",
     "start_time": "2024-03-01T07:06:05.539503700Z"
    }
   },
   "outputs": [],
   "source": [
    "# 获取输入序列中的第二个元素作为特定的输入向量\n",
    "x_2 = inputs[1]\n",
    "\n",
    "# 获取输入张量的嵌入维度，这里假设输入向量的维度为3\n",
    "d_in = inputs.shape[1]\n",
    "\n",
    "# 设置输出嵌入的维度，这里假设输出向量的维度为2\n",
    "d_out = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f528cfb3-e226-47dd-b363-cc2caaeba4bf",
   "metadata": {},
   "source": [
    "- 下面，我们初始化这三个权重矩阵；请注意，我们设置了 `requires_grad=False` 以减少输出中的混乱，这是为了说明目的。但是，如果我们要在模型训练中使用这些权重矩阵，我们会将 `requires_grad` 设置为 `True`，以便在模型训练过程中更新这些矩阵。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:09.001276300Z",
     "start_time": "2024-03-01T07:06:08.956267200Z"
    }
   },
   "outputs": [],
   "source": [
    "# 设置随机种子以确保结果的可重复性\n",
    "torch.manual_seed(123)\n",
    "\n",
    "# 创建查询权重矩阵，形状为 (d_in, d_out)，并且设置 requires_grad=False 表示这些权重在训练过程中不会更新\n",
    "W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)\n",
    "\n",
    "# 创建键权重矩阵，形状和 W_query 相同，同样设置 requires_grad=False\n",
    "W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)\n",
    "\n",
    "# 创建值权重矩阵，形状和 W_query 相同，同样设置 requires_grad=False\n",
    "W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abfd0b50-7701-4adb-821c-e5433622d9c4",
   "metadata": {},
   "source": [
    "- 接下来，我们计算查询、键和值向量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "73cedd62-01e1-4196-a575-baecc6095601",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:12.380032800Z",
     "start_time": "2024-03-01T07:06:12.363027500Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.4306, 1.4551])\n"
     ]
    }
   ],
   "source": [
    "# 使用查询权重矩阵 W_query 将第二个输入元素 x_2 投影到查询空间\n",
    "query_2 = x_2 @ W_query  # 使用 @ 运算符进行矩阵乘法\n",
    "\n",
    "# 使用键权重矩阵 W_key 将第二个输入元素 x_2 投影到键空间\n",
    "key_2 = x_2 @ W_key\n",
    "\n",
    "# 使用值权重矩阵 W_value 将第二个输入元素 x_2 投影到值空间\n",
    "value_2 = x_2 @ W_value\n",
    "\n",
    "# 打印计算得到的查询向量\n",
    "print(query_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9be308b3-aca3-421b-b182-19c3a03b71c7",
   "metadata": {},
   "source": [
    "- 从下面的结果可以看出，我们成功地将6个输入标记从一个3维空间投影到了一个2维嵌入空间："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:15.416710300Z",
     "start_time": "2024-03-01T07:06:15.403706900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "keys.shape: torch.Size([6, 2])\n",
      "values.shape: torch.Size([6, 2])\n"
     ]
    }
   ],
   "source": [
    "# 使用键权重矩阵 W_key 将输入序列 inputs 投影到键空间\n",
    "keys = inputs @ W_key\n",
    "\n",
    "# 使用值权重矩阵 W_value 将输入序列 inputs 投影到值空间\n",
    "values = inputs @ W_value\n",
    "\n",
    "# 打印键向量的形状\n",
    "print(\"keys.shape:\", keys.shape)\n",
    "\n",
    "# 打印值向量的形状\n",
    "print(\"values.shape:\", values.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bac5dfd6-ade8-4e7b-b0c1-bed40aa24481",
   "metadata": {},
   "source": [
    "- 在下一步，即**第二步**中，我们通过计算查询向量与每个键向量之间的点积来计算未归一化的注意力分数："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ed0a2b7-5c50-4ede-90cf-7ad74412b3aa",
   "metadata": {},
   "source": [
    "<img src=\"figures/weight-selfattn-2.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "64cbc253-a182-4490-a765-246979ea0a28",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:19.620649700Z",
     "start_time": "2024-03-01T07:06:19.603645500Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.8524)\n"
     ]
    }
   ],
   "source": [
    "# 从键张量中提取第二个键向量，对应于输入序列中的第二个元素\n",
    "keys_2 = keys[1]  # Python 中的索引是从 0 开始的\n",
    "\n",
    "# 计算查询向量 query_2 和键向量 keys_2 之间的点积，得到注意力分数\n",
    "attn_score_22 = query_2.dot(keys_2)\n",
    "\n",
    "# 打印计算得到的注意力分数\n",
    "print(attn_score_22)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e9d15c0-c24e-4e6f-a160-6349b418f935",
   "metadata": {},
   "source": [
    "- 由于我们有6个输入，对于给定的查询向量，我们得到了6个注意力分数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b14e44b5-d170-40f9-8847-8990804af26d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:22.436278900Z",
     "start_time": "2024-03-01T07:06:22.426277700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])\n"
     ]
    }
   ],
   "source": [
    "# 使用查询向量 query_2 和所有键向量 keys 的转置（keys.T）进行矩阵乘法，得到所有注意力分数\n",
    "attn_scores_2 = query_2 @ keys.T\n",
    "\n",
    "# 打印所有注意力分数\n",
    "print(attn_scores_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8622cf39-155f-4eb5-a0c0-82a03ce9b999",
   "metadata": {},
   "source": [
    "<img src=\"figures/weight-selfattn-3.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1609edb-f089-461a-8de2-c20c1bb29836",
   "metadata": {},
   "source": [
    "- 接下来，在**第三步**中，我们使用之前提到的 softmax 函数来计算注意力权重（归一化的注意力分数，它们的总和为1）。\n",
    "- 与之前不同的是，我们现在通过除以嵌入维度的平方根，$\\sqrt{d_k}$（即 `d_k**0.5`），来缩放注意力分数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "146f5587-c845-4e30-9894-c7ed3a248153",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:26.976293500Z",
     "start_time": "2024-03-01T07:06:26.972292100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])\n"
     ]
    }
   ],
   "source": [
    "# 获取键向量的维度，即每个键向量的维度大小\n",
    "d_k = keys.shape[1]\n",
    "\n",
    "# 使用 softmax 函数对注意力分数进行归一化处理\n",
    "# attn_scores_2 / d_k**0.5 是缩放点积分数，防止过大的点积值导致 softmax 梯度过小\n",
    "# dim=-1 表示沿着最后一个维度（即注意力分数的维度）进行 softmax 计算\n",
    "attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)\n",
    "\n",
    "# 打印归一化后的注意力权重\n",
    "print(attn_weights_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8f61a28-b103-434a-aee1-ae7cbd821126",
   "metadata": {},
   "source": [
    "<img src=\"figures/weight-selfattn-4.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1890e3f9-db86-4ab8-9f3b-53113504a61f",
   "metadata": {},
   "source": [
    "- 在**第四步**中，我们现在计算输入查询向量2的上下文向量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:33.820850800Z",
     "start_time": "2024-03-01T07:06:33.810848200Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.3061, 0.8210])\n"
     ]
    }
   ],
   "source": [
    "# 使用归一化的注意力权重 attn_weights_2 和值向量 values 进行矩阵乘法，得到上下文向量\n",
    "context_vec_2 = attn_weights_2 @ values\n",
    "\n",
    "# 打印计算得到的上下文向量\n",
    "print(context_vec_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d7b2907-e448-473e-b46c-77735a7281d8",
   "metadata": {},
   "source": [
    "### 3.4.2 实现一个紧凑的 SelfAttention 类"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04313410-3155-4d90-a7a3-2f3386e73677",
   "metadata": {},
   "source": [
    "- 将所有内容整合起来，我们可以按照以下方式实现自注意力机制："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:06:40.953444700Z",
     "start_time": "2024-03-01T07:06:40.916436100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2996, 0.8053],\n",
      "        [0.3061, 0.8210],\n",
      "        [0.3058, 0.8203],\n",
      "        [0.2948, 0.7939],\n",
      "        [0.2927, 0.7891],\n",
      "        [0.2990, 0.8040]], grad_fn=<MmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# 定义自注意力模块\n",
    "class SelfAttention_v1(nn.Module):\n",
    "    def __init__(self, d_in, d_out):\n",
    "        # 调用父类构造函数\n",
    "        super().__init__()\n",
    "        # 设置输出维度\n",
    "        self.d_out = d_out\n",
    "        # 初始化查询、键和值的权重矩阵，这些矩阵是可训练的\n",
    "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
    "        self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n",
    "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 使用权重矩阵将输入 x 投影到查询、键和值空间\n",
    "        keys = x @ self.W_key\n",
    "        queries = x @ self.W_query\n",
    "        values = x @ self.W_value\n",
    "        \n",
    "        # 计算注意力分数（未归一化）\n",
    "        attn_scores = queries @ keys.T  # omega\n",
    "        \n",
    "        # 使用 softmax 函数和缩放因子归一化注意力分数\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "\n",
    "        # 使用归一化的注意力权重和值向量计算上下文向量\n",
    "        context_vec = attn_weights @ values\n",
    "        return context_vec\n",
    "\n",
    "# 设置随机种子以确保结果的可重复性\n",
    "torch.manual_seed(123)\n",
    "# 创建 SelfAttention_v1 实例\n",
    "sa_v1 = SelfAttention_v1(d_in, d_out)\n",
    "# 使用输入数据 inputs 进行前向传播，并打印结果\n",
    "print(sa_v1(inputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "048e0c16-d911-4ec8-b0bc-45ceec75c081",
   "metadata": {},
   "source": [
    "- 我们可以使用 PyTorch 的线性层 `nn.Linear` 来简化上述实现，如果我们关闭偏置单元，它们就等同于矩阵乘法。\n",
    "- 使用 `nn.Linear` 而不是我们手动创建的 `nn.Parameter(torch.rand(...))` 方法的另一个重要优势是，`nn.Linear` 自带了一种优选的权重初始化方案，这有助于实现更稳定的模型训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "73f411e3-e231-464a-89fe-0a9035e5f839",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-01T07:08:05.227348900Z",
     "start_time": "2024-03-01T07:08:05.210344600Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.0739,  0.0713],\n",
      "        [-0.0748,  0.0703],\n",
      "        [-0.0749,  0.0702],\n",
      "        [-0.0760,  0.0685],\n",
      "        [-0.0763,  0.0679],\n",
      "        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# 定义自注意力模块的第二个版本\n",
    "class SelfAttention_v2(nn.Module):\n",
    "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
    "        # 调用父类构造函数\n",
    "        super().__init__()\n",
    "        # 设置输出维度\n",
    "        self.d_out = d_out\n",
    "        # 初始化查询、键和值的线性层，可以选择是否包含偏置项\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 使用线性层将输入 x 投影到查询、键和值空间\n",
    "        keys = self.W_key(x)\n",
    "        queries = self.W_query(x)\n",
    "        values = self.W_value(x)\n",
    "        \n",
    "        # 计算注意力分数（未归一化）\n",
    "        attn_scores = queries @ keys.T\n",
    "        \n",
    "        # 使用 softmax 函数和缩放因子归一化注意力分数\n",
    "        # 注意这里的 dim=1，表示沿着键向量的维度进行归一化\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
    "\n",
    "        # 使用归一化的注意力权重和值向量计算上下文向量\n",
    "        context_vec = attn_weights @ values\n",
    "        return context_vec\n",
    "\n",
    "# 设置随机种子以确保结果的可重复性\n",
    "torch.manual_seed(789)\n",
    "# 创建 SelfAttention_v2 实例\n",
    "sa_v2 = SelfAttention_v2(d_in, d_out)\n",
    "# 使用输入数据 inputs 进行前向传播，并打印结果\n",
    "print(sa_v2(inputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "915cd8a5-a895-42c9-8b8e-06b5ae19ffce",
   "metadata": {},
   "source": [
    "- 请注意，`SelfAttention_v1` 和 `SelfAttention_v2` 会产生不同的输出，因为它们使用了不同的初始权重矩阵。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5025b37-0f2c-4a67-a7cb-1286af7026ab",
   "metadata": {},
   "source": [
    "## 3.5 遮蔽下文信息的注意力机制"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82f405de-cd86-4e72-8f3c-9ea0354946ba",
   "metadata": {},
   "source": [
    "### 3.5.1 使用因果注意力掩码"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "014f28d0-8218-48e4-8b9c-bdc5ce489218",
   "metadata": {},
   "source": [
    "- 在本节中，我们将前面的自注意力机制转换为因果自注意力机制。\n",
    "\n",
    "- 因果自注意力机制的核心目标是，确保模型对序列中某个位置的预测只依赖于前面位置的已知输出（也就是上文），而不依赖于未来位置（也就是下文）。也就是说，确保每一个词的预测只应该依赖于前面的词。\n",
    "\n",
    "- 为了实现这一点，对于每个给定的词，我们屏蔽掉未来的词（即输入文本中在当前词之后的词）。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71e91bb5-5aae-4f05-8a95-973b3f988a35",
   "metadata": {},
   "source": [
    "<img src=\"figures/masked.png\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbfaec7a-68f2-4157-a4b5-2aeceed199d9",
   "metadata": {},
   "source": [
    "- 为了说明和实现因果自注意力机制，让我们使用上一节的注意力分数和权重进行操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n",
      "        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n",
      "        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n",
      "        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n",
      "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n",
      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
      "       grad_fn=<SoftmaxBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 使用上一节中 SelfAttention_V2 的 query 和 key 的权重矩阵\n",
    "queries = sa_v2.W_query(inputs)\n",
    "keys = sa_v2.W_key(inputs) \n",
    "attn_scores = queries @ keys.T\n",
    "# 此处的注意力权重和上一节中的一致\n",
    "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
    "print(attn_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89020a96-b34d-41f8-9349-98c3e23fd5d6",
   "metadata": {},
   "source": [
    "- 屏蔽未来的注意力权重最简单的方法是通过 PyTorch 的 tril 函数创建一个掩码，主对角线（包括对角线本身）以下的元素设置为 1，主对角线以上的元素设置为 0："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 0., 0., 0., 0., 0.],\n",
      "        [1., 1., 0., 0., 0., 0.],\n",
      "        [1., 1., 1., 0., 0., 0.],\n",
      "        [1., 1., 1., 1., 0., 0.],\n",
      "        [1., 1., 1., 1., 1., 0.],\n",
      "        [1., 1., 1., 1., 1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "# 我们创建的掩码形状应该和注意力权重矩阵的形状一致，以一一对应\n",
    "block_size = attn_scores.shape[0]\n",
    "# tril 方法会创建一个下三角矩阵\n",
    "mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
    "print(mask_simple)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efce2b08-3583-44da-b3fc-cabdd38761f6",
   "metadata": {},
   "source": [
    "- 然后，我们可以将注意力权重与这个掩码相乘，从而将对角线以上的注意力分数归零："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n",
      "        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n",
      "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n",
      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
      "       grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "masked_simple = attn_weights*mask_simple\n",
    "print(masked_simple)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3eb35787-cf12-4024-b66d-e7215e175500",
   "metadata": {},
   "source": [
    "- 然而，如果像上文一样，在 softmax 之后再应用掩码，它会破坏 softmax 创建的概率分布。Softmax将确保所有输出值的总和为 1，但由于我们将部分输出值置为了 0，这将导致输出值总和发生变化。\n",
    "\n",
    "- 因此，在 softmax 之后进行掩码处理将需要重新对输出进行归一化，使其总和再次为 1。但是，这使得过程变得复杂，并可能导致意想不到的效果。\n",
    "\n",
    "- 为了确保输出值的总和为 1，我们可以将权重矩阵进行如下的归一化："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
      "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
      "       grad_fn=<DivBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# dim = 1 表示按行求和\n",
    "row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
    "masked_simple_norm = masked_simple / row_sums\n",
    "print(masked_simple_norm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
   "metadata": {},
   "source": [
    "- 尽管我们现在在技术上已经完成了因果注意力机制，但还有一些实现上述相同效果的更有效的方法。\n",
    "\n",
    "- 例如，我们可以在未归一化的注意力分数进入 softmax 函数之前，用负无穷大掩盖对角线以上的部分，而不是将对角线以上的注意力权重归零并重新归一化结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a2be2f43-9cf0-44f6-8d8b-68ef2fb3cc39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],\n",
      "        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],\n",
      "        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],\n",
      "        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],\n",
      "        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],\n",
      "        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n",
      "       grad_fn=<MaskedFillBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 也就是说，通过将掩码从 0 修改为 -inf，可以将遮蔽操作提到 softmax 之前\n",
    "mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
    "masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
    "print(masked)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91d5f803-d735-4543-b9da-00ac10fb9c50",
   "metadata": {},
   "source": [
    "- 正如我们所见，接下来我们再让注意力矩阵通过 softmax，就可以将每行之和都重新变回 1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b1cd6d7f-16f2-43c1-915e-0824f1a4bc52",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
      "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
      "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
      "       grad_fn=<SoftmaxBackward0>)\n"
     ]
    }
   ],
   "source": [
    "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)\n",
    "print(attn_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7636fc5f-6bc6-461e-ac6a-99ec8e3c0912",
   "metadata": {},
   "source": [
    "### 3.5.2 通过 dropout 来实现额外注意力权重的掩码"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec3dc7ee-6539-4fab-804a-8f31a890c85a",
   "metadata": {},
   "source": [
    "- 此外，我们还可以在训练过程中应用 dropout 来减少过拟合。\n",
    "\n",
    "- dropout 可以应用于例如下列例子的多个地方：\n",
    "    - 计算注意力权重后；\n",
    "    - 将注意力权重与值向量相乘后。\n",
    "\n",
    "- 在这里，我们将在计算注意力权重后应用 dropout 掩码，因为这种情况更为常见。\n",
    "\n",
    "- 此外，在这个特定的例子中，我们使用了 50% 的 dropout 率，这意味着随机屏蔽一半的注意力权重。（当我们稍后训练 GPT 模型时，我们将使用较低的 dropout 率，例如 0.1 或 0.2。）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee799cf6-6175-45f2-827e-c174afedb722",
   "metadata": {},
   "source": [
    "<img src=\"figures/dropout.png\" width=\"500px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a575458-a6da-4e54-8688-83e155f2de06",
   "metadata": {},
   "source": [
    "- 注意，如果我们应用 0.5 的 dropout 率，那么未被屏蔽的值将按照 1/0.5 = 2 的比例进行相应缩放。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "0de578db-8289-41d6-b377-ef645751e33f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2., 2., 2., 2., 2.],\n",
      "        [0., 2., 0., 0., 0., 0.],\n",
      "        [0., 0., 2., 0., 2., 0.],\n",
      "        [2., 2., 0., 0., 0., 2.],\n",
      "        [2., 0., 0., 0., 0., 2.],\n",
      "        [0., 2., 0., 0., 0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "# 随便设置一个随机数种子\n",
    "torch.manual_seed(123)\n",
    "dropout = torch.nn.Dropout(0.5) # 设置 50% 的 Dropout 比例\n",
    "example = torch.ones(6, 6) # 创建一个全 1 矩阵作为示例\n",
    "\n",
    "print(dropout(example))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b16c5edb-942b-458c-8e95-25e4e355381e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.3843, 0.3293, 0.3303, 0.3100, 0.3442, 0.3019],\n",
      "        [0.0000, 0.3318, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.3325, 0.0000, 0.3328, 0.0000],\n",
      "        [0.3738, 0.3334, 0.0000, 0.0000, 0.0000, 0.3128],\n",
      "        [0.3661, 0.0000, 0.0000, 0.0000, 0.0000, 0.3169],\n",
      "        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],\n",
      "       grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "# 对注意力权重进行 dropout\n",
    "print(dropout(attn_weights))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
   "metadata": {},
   "source": [
    "## 3.5.3 实现一个因果自注意类"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09c41d29-1933-43dc-ada6-2dbb56287204",
   "metadata": {},
   "source": [
    "- 现在，我们已经准备好实现一个包含 dropout 的因果自注意力类。\n",
    "\n",
    "- 我们还需要实现处理由多个输入组成的一批样本的代码，以便我们的 CausalAttention 类支持我们在第2章中实现的 dataloader 产生的批量输出。\n",
    "\n",
    "- 为了简化，为了模拟这样的批量输入，我们复制输入文本示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "977a5fa7-a9d5-4e2e-8a32-8e0331ccfe28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 6, 3])\n"
     ]
    }
   ],
   "source": [
    "batch = torch.stack((inputs, inputs), dim=0)\n",
    "print(batch.shape) # 2个输入，每个输入有 6个 token，每个 token 的维度为 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "60d8c2eb-2d8e-4d2c-99bc-9eef8cc53ca0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.0844,  0.0414],\n",
      "         [-0.2264, -0.0039],\n",
      "         [-0.4163, -0.0564],\n",
      "         [-0.5014, -0.1011],\n",
      "         [-0.7754, -0.1867],\n",
      "         [-1.1632, -0.3303]],\n",
      "\n",
      "        [[-0.0844,  0.0414],\n",
      "         [-0.2264, -0.0039],\n",
      "         [-0.4163, -0.0564],\n",
      "         [-0.5014, -0.1011],\n",
      "         [-0.7754, -0.1867],\n",
      "         [-1.1632, -0.3303]]], grad_fn=<UnsafeViewBackward0>)\n",
      "context_vecs.shape: torch.Size([2, 6, 2])\n"
     ]
    }
   ],
   "source": [
    "# 定义一个带 dropout 的因果自注意力层\n",
    "class CausalAttention(nn.Module):\n",
    "\n",
    "    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
    "        '''\n",
    "        构造函数，输入参数如下：\n",
    "        d_in: 输入的维度\n",
    "        d_out: 输出的维度\n",
    "        block_size: 注意力权重矩阵的大小\n",
    "        dropout: dropout 比例\n",
    "        qkv_bias: 是否对 query、key 和 value 加偏置\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.d_out = d_out\n",
    "        # 根据前文，每一个权重矩阵都是 d_in x d_out 的线性层\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        # 一个 dropout 层\n",
    "        self.dropout = nn.Dropout(dropout) \n",
    "        # 一个掩码矩阵，下三角为 1，其余为 0\n",
    "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        前向传播函数，输入参数为 x，维度为 b x num_tokens x d_in，输出维度为 b x num_tokens x d_out\n",
    "        '''\n",
    "        b, num_tokens, d_in = x.shape\n",
    "        keys = self.W_key(x)\n",
    "        queries = self.W_query(x)\n",
    "        values = self.W_value(x)\n",
    "        # transpose 是为了实现矩阵乘法\n",
    "        attn_scores = queries @ keys.transpose(1, 2)\n",
    "        # 即上文说过的，将掩码从 0 修改为 -inf，再进行遮蔽操作\n",
    "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
    "        # 经过 softmax \n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
    "        # 进行 dropout\n",
    "        attn_weights = self.dropout(attn_weights) # New\n",
    "        # 得到最后结果\n",
    "        context_vec = attn_weights @ values\n",
    "        return context_vec\n",
    "\n",
    "# 实验一下\n",
    "torch.manual_seed(123)\n",
    "\n",
    "block_size = batch.shape[1]\n",
    "ca = CausalAttention(d_in, d_out, block_size, 0.0)\n",
    "\n",
    "context_vecs = ca(batch)\n",
    "\n",
    "print(context_vecs)\n",
    "print(\"context_vecs.shape:\", context_vecs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4333d12-17e4-4bb5-9d83-54b3a32618cd",
   "metadata": {},
   "source": [
    "- 注意 dropout 只在训练阶段被使用，在推理阶段是不使用的"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8bef90f-cfd4-4289-b0e8-6a00dc9be44c",
   "metadata": {},
   "source": [
    "## 3.6 将单头注意力扩展到多头"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11697757-9198-4a1c-9cee-f450d8bbd3b9",
   "metadata": {},
   "source": [
    "### 3.6.1 直接将多个单头注意力层堆积起来"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70766faf-cd53-41d9-8a17-f1b229756a5a",
   "metadata": {},
   "source": [
    "- 下图是之前提到过的自注意力的总结（为了简便起见，因果注意力掩码和 dropout 并没有展示） \n",
    "\n",
    "- 也被称之为单头注意力:\n",
    "\n",
    "<img src=\"figures/single-head.png\" width=\"600px\">\n",
    "\n",
    "- 我们可以简单地将多个单头注意力层堆积在一起实现多头注意力层:\n",
    "\n",
    "<img src=\"figures/multi-head.png\" width=\"600px\">\n",
    "\n",
    "- 多头注意力机制的主要思想是使用不同的、已学习的权重矩阵，多次（并行）运行注意力机制。这使得模型能够在不同位置的不同表示子空间中联合关注信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.0844,  0.0414,  0.0766,  0.0171],\n",
      "         [-0.2264, -0.0039,  0.2143,  0.1185],\n",
      "         [-0.4163, -0.0564,  0.3878,  0.2453],\n",
      "         [-0.5014, -0.1011,  0.4992,  0.3401],\n",
      "         [-0.7754, -0.1867,  0.7387,  0.4868],\n",
      "         [-1.1632, -0.3303,  1.1224,  0.8460]],\n",
      "\n",
      "        [[-0.0844,  0.0414,  0.0766,  0.0171],\n",
      "         [-0.2264, -0.0039,  0.2143,  0.1185],\n",
      "         [-0.4163, -0.0564,  0.3878,  0.2453],\n",
      "         [-0.5014, -0.1011,  0.4992,  0.3401],\n",
      "         [-0.7754, -0.1867,  0.7387,  0.4868],\n",
      "         [-1.1632, -0.3303,  1.1224,  0.8460]]], grad_fn=<CatBackward0>)\n",
      "context_vecs.shape: torch.Size([2, 6, 4])\n"
     ]
    }
   ],
   "source": [
    "# 定义一个多头注意力层\n",
    "class MultiHeadAttentionWrapper(nn.Module):\n",
    "\n",
    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
    "        super().__init__()\n",
    "            # 将 num_heads 个单头注意力层组合在一起来实现多头\n",
    "        self.heads = nn.ModuleList(\n",
    "            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
    "             for _ in range(num_heads)]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 前向计算时将多个头的输出拼接在一起\n",
    "        return torch.cat([head(x) for head in self.heads], dim=-1)\n",
    "\n",
    "\n",
    "# 实验一下\n",
    "torch.manual_seed(123)\n",
    "\n",
    "block_size = batch.shape[1] # token 数量\n",
    "d_in, d_out = 3, 2\n",
    "mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
    "\n",
    "context_vecs = mha(batch)\n",
    "\n",
    "print(context_vecs)\n",
    "print(\"context_vecs.shape:\", context_vecs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "193d3d2b-2578-40ba-b791-ea2d49328e48",
   "metadata": {},
   "source": [
    "- 在上面的实现中，嵌入维度是4，因为我们为 key、query、value 都设置了 d_out=2 作为嵌入维度。由于我们有2个注意力头，因此输出嵌入维度为 2*2=4。\n",
    "\n",
    "- 如果我们想要输出维度为2，就像早期的单头注意力那样，我们可以将投影维度 d_out 更改为1："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "dc9a4375-068b-4b2a-aabb-a29347ca5ecd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-9.1476e-02,  3.4164e-02],\n",
      "         [-2.6796e-01, -1.3427e-03],\n",
      "         [-4.8421e-01, -4.8909e-02],\n",
      "         [-6.4808e-01, -1.0625e-01],\n",
      "         [-8.8380e-01, -1.7140e-01],\n",
      "         [-1.4744e+00, -3.4327e-01]],\n",
      "\n",
      "        [[-9.1476e-02,  3.4164e-02],\n",
      "         [-2.6796e-01, -1.3427e-03],\n",
      "         [-4.8421e-01, -4.8909e-02],\n",
      "         [-6.4808e-01, -1.0625e-01],\n",
      "         [-8.8380e-01, -1.7140e-01],\n",
      "         [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)\n",
      "context_vecs.shape: torch.Size([2, 6, 2])\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "d_out = 1\n",
    "mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
    "\n",
    "context_vecs = mha(batch)\n",
    "\n",
    "print(context_vecs)\n",
    "print(\"context_vecs.shape:\", context_vecs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6836b5da-ef82-4b4c-bda1-72a462e48d4e",
   "metadata": {},
   "source": [
    "### 3.6.2 通过权重分割实现多头注意力"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4b48d0d-71ba-4fa0-b714-ca80cabcb6f7",
   "metadata": {},
   "source": [
    "- 尽管上述是多头注意力最直观且功能完整的实现（将早期的单头注意力 CausalAttention 实现封装在内），但我们也可以编写一个名为MultiHeadAttention 的独立类来实现相同的功能。\n",
    "\n",
    "- 对于这个独立的 MultiHeadAttention 类，我们不会将单个注意力头连接在一起。相反，我们创建单个的 W_query、W_key 和 W_value 权重矩阵，然后将它们拆分为每个注意力头的独立矩阵："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "110b0188-6e9e-4e56-a988-10523c6c8538",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.3190, 0.4858],\n",
      "         [0.2943, 0.3897],\n",
      "         [0.2856, 0.3593],\n",
      "         [0.2693, 0.3873],\n",
      "         [0.2639, 0.3928],\n",
      "         [0.2575, 0.4028]],\n",
      "\n",
      "        [[0.3190, 0.4858],\n",
      "         [0.2943, 0.3897],\n",
      "         [0.2856, 0.3593],\n",
      "         [0.2693, 0.3873],\n",
      "         [0.2639, 0.3928],\n",
      "         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)\n",
      "context_vecs.shape: torch.Size([2, 6, 2])\n"
     ]
    }
   ],
   "source": [
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
    "        super().__init__()\n",
    "        # 因为要对权重矩阵按注意力头数进行拆分，所有输出维度必须是头数的整数倍\n",
    "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
    "\n",
    "        self.d_out = d_out\n",
    "        self.num_heads = num_heads\n",
    "        # head_dim 就是拆分之后每个头应该输出的维度\n",
    "        self.head_dim = d_out // num_heads \n",
    "\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
    "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        b, num_tokens, d_in = x.shape\n",
    "\n",
    "        # 形状为 (b, num_tokens, d_out)\n",
    "        keys = self.W_key(x)\n",
    "        queries = self.W_query(x)\n",
    "        values = self.W_value(x)\n",
    "\n",
    "        # 我们可以通过增加一个 num_heads 的维度来将矩阵分割到每个头\n",
    "        # 维度变化: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
    "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) \n",
    "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "\n",
    "        # 转置一下: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
    "        keys = keys.transpose(1, 2)\n",
    "        queries = queries.transpose(1, 2)\n",
    "        values = values.transpose(1, 2)\n",
    "\n",
    "        # 计算注意力权重\n",
    "        # 基于矩阵乘法，简单地实现各个头的并行计算\n",
    "        attn_scores = queries @ keys.transpose(2, 3) \n",
    "        # 一般来说我们会将掩码矩阵转化为 bool 值并基于序列的长度进行截断\n",
    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
    "        # 需要将掩码矩阵 unsqueeze 两次，也就是增加两个维度，才能让掩码矩阵的维度和注意力权重对应上\n",
    "        mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n",
    "        # 使用掩码矩阵来进行遮蔽\n",
    "        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
    "        \n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "        attn_weights = self.dropout(attn_weights)\n",
    "\n",
    "        # 形状: (b, num_tokens, num_heads, head_dim)\n",
    "        context_vec = (attn_weights @ values).transpose(1, 2) \n",
    "        \n",
    "        # 将多个头的输出重新组合回去 self.d_out = self.num_heads * self.head_dim\n",
    "        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
    "        context_vec = self.out_proj(context_vec) # optional projection\n",
    "\n",
    "        return context_vec\n",
    "\n",
    "# 试验一下\n",
    "torch.manual_seed(123)\n",
    "\n",
    "batch_size, block_size, d_in = batch.shape\n",
    "d_out = 2\n",
    "mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
    "\n",
    "context_vecs = mha(batch)\n",
    "\n",
    "print(context_vecs)\n",
    "print(\"context_vecs.shape:\", context_vecs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d334dfb5-2b6c-4c33-82d5-b4e9db5867bb",
   "metadata": {},
   "source": [
    "- 请注意，以上内容实际上是 MultiHeadAttentionWrapper 的一个更高效的改写版。\n",
    "- 由于随机权重初始化的差异，最终的输出结果看起来有些不同，但两者都是完全可以使用的实现，将在后续章节中实现的GPT类中使用。\n",
    "- 此外，我们在上述 MultiHeadAttention 类中添加了一个线性投影层(self.out_proj)。这只是一个不会改变维度的线性变换。在LLM实现中使用这样的投影层是一种标准惯例，但并非严格必要（最近的研究表明，它可以被移除而不会影响建模性能；见本章末尾的进一步阅读部分）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b0ed78c-e8ac-4f8f-a479-a98242ae8f65",
   "metadata": {},
   "source": [
    "- 如果你对更复杂、高效的多头注意力实现感兴趣，你可以考虑使用 PyTorch 的 [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) 类。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "363701ad-2022-46c8-9972-390d2a2b9911",
   "metadata": {},
   "source": [
    "- 上述实现可能看上去有一点复杂，让我们来看一下，当运行 `attn_scores = queries @ keys.transpose(2, 3)` 时会发生什么："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[1.3208, 1.1631, 1.2879],\n",
      "          [1.1631, 2.2150, 1.8424],\n",
      "          [1.2879, 1.8424, 2.0402]],\n",
      "\n",
      "         [[0.4391, 0.7003, 0.5903],\n",
      "          [0.7003, 1.3737, 1.0620],\n",
      "          [0.5903, 1.0620, 0.9912]]]])\n"
     ]
    }
   ],
   "source": [
    "# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)\n",
    "a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],\n",
    "                    [0.8993, 0.0390, 0.9268, 0.7388],\n",
    "                    [0.7179, 0.7058, 0.9156, 0.4340]],\n",
    "\n",
    "                   [[0.0772, 0.3565, 0.1479, 0.5331],\n",
    "                    [0.4066, 0.2318, 0.4545, 0.9737],\n",
    "                    [0.4606, 0.5159, 0.4220, 0.5786]]]])\n",
    "\n",
    "print(a @ a.transpose(2, 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0587b946-c8f2-4888-adbf-5a5032fbfd7b",
   "metadata": {},
   "source": [
    "- 在这种情况下，PyTorch 中的矩阵乘法实现将处理 4 维输入张量，以便在最后的两个维度（num_tokens，head_dim）之间进行矩阵乘法，然后针对各个头重复进行。\n",
    "\n",
    "- 例如，上述内容成为了一种单独计算每个头的更紧凑的矩阵乘法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "053760f1-1a02-42f0-b3bf-3d939e407039",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First head:\n",
      " tensor([[1.3208, 1.1631, 1.2879],\n",
      "        [1.1631, 2.2150, 1.8424],\n",
      "        [1.2879, 1.8424, 2.0402]])\n",
      "\n",
      "Second head:\n",
      " tensor([[0.4391, 0.7003, 0.5903],\n",
      "        [0.7003, 1.3737, 1.0620],\n",
      "        [0.5903, 1.0620, 0.9912]])\n"
     ]
    }
   ],
   "source": [
    "first_head = a[0, 0, :, :]\n",
    "first_res = first_head @ first_head.T\n",
    "print(\"First head:\\n\", first_res)\n",
    "\n",
    "second_head = a[0, 1, :, :]\n",
    "second_res = second_head @ second_head.T\n",
    "print(\"\\nSecond head:\\n\", second_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2360064"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "block_size = 1024\n",
    "d_in, d_out = 768, 768\n",
    "num_heads = 12\n",
    "\n",
    "mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "count_parameters(mha)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
   "metadata": {},
   "source": [
    "# 总结与收获"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa3e4113-ffca-432c-b3ec-7a50bd15da25",
   "metadata": {},
   "source": [
    "- 你可以查看 [./multihead-attention.ipynb](./multihead-attention.ipynb) 代码 Notebook，这是 DataLoader（第2章）的简洁版本，加上我们在本章中实现的多头注意力类，我们将在后续章节中训练 GPT 模型时使用它。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
